# SPDX-FileCopyrightText: 2022-2025 Institute of Computing Technology, Chinese Academy of Sciences
# SPDX-License-Identifier: AGPL-3.0-or-later

import argparse
import time
from typing import Dict, List
import torch
import numpy as np
import sys
import os

import sys
sys.path.append('..')
import models
from utils.view_pt import select_weight_file
from anypacking.quant_module import HWGQ, QuantConv2d, ImageInputQ, QuantLinear

class ConvParam:
    '''Plain attribute bag describing one extracted conv/linear layer.

    It has no fields of its own; the extraction and processing functions
    attach attributes such as ``k``, ``ich``, ``och``, ``abit``, ``wbit`` and
    ``w`` dynamically.
    '''

def write_hls_config(model_param, path):
    '''write_hls_config(model_param, path)
    Write ``path + 'config.h'`` with one ``#define`` per known layer attribute.

    For layer ``n`` of type ``t``, the file gets a ``// t_n`` comment line.
    It is followed by ``#define T_n_MACRO value`` for every
    (attribute, MACRO) pair in the table below that the layer object actually
    has (e.g. the last layer has no ``incbit``).

    The banner reads the module-level ``parser`` and ``opt`` objects created
    in ``__main__``.
    '''
    # (layer attribute, HLS macro suffix). Stride 's' and padding 'p' are
    # not emitted.
    attr_to_macro = (
        ('k', 'K'),
        ('ich', 'IFM_CH'),
        ('irow', 'IFM_ROW'),
        ('icol', 'IFM_COL'),
        ('och', 'OFM_CH'),
        ('orow', 'OFM_ROW'),
        ('ocol', 'OFM_COL'),
        ('abit', 'IN_BIT'),
        ('wbit', 'W_BIT'),
        ('incbit', 'INC_BIT'),
        ('biasbit', 'BIAS_BIT'),
        ('simd', 'SIMD'),
        ('pe', 'PE'),
        ('lshift', 'L_SHIFT'),
    )
    banner = f'''/********************************************************************************
* Filename: config.h
* Date: {time.ctime()}
* Description: This file is generated by {parser.prog}
*   ptfilename: {opt.weight} 
********************************************************************************/

#ifndef _CONFIG_H_
#define _CONFIG_H_

'''
    parts = [banner]
    for idx, layer in enumerate(model_param):
        parts.append(f'// {layer.type}_{idx}\n')
        parts.extend(
            f'#define {layer.type.upper()}_{idx}_{macro} {getattr(layer, attr)}\n'
            for attr, macro in attr_to_macro
            if hasattr(layer, attr)
        )
        parts.append('\n')
    parts.append('#endif')

    # Trailing newline matches what print() appended to the content.
    with open(path + 'config.h', 'w') as out:
        out.write(''.join(parts) + '\n')

def _as_pair(value):
    '''Return a layer size attribute as a tuple (PyTorch accepts int or tuple).'''
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value, value)


def _window_out_size(size, kernel, stride, padding, dilation=1):
    '''Output length along one spatial axis of a conv/pool window (floor mode).'''
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def extract_model(in_shape, net=None):
    '''extract_model(in_shape, net=None)
    Walk the model's modules and collect one ConvParam per conv/linear layer.

    Expected order per layer: [QAct] -> [Pooling] -> Conv/Linear -> [BN] -> [Pooling].
    - An activation quantizer (HWGQ / ImageInputQ) sets abit/astep/actq_class
      of the next layer, and obit/ostep of the previous layer.
    - Conv2d / Linear append a ConvParam with shapes and exported quantized
      weights.
    - BatchNorm2d is folded into bn_w / bn_b of the last extracted layer.
    - MaxPool2d only updates the tracked feature-map size (ceil_mode is not
      modeled).

    Params:
    in_shape: [channels, rows, cols] of the network input. It is copied, not
        modified.
    net: module to walk. Defaults to the module-level ``model`` built in
        ``__main__``.

    Returns the list of ConvParam objects in module order.
    Raises NotImplementedError for a Conv2d/Linear that is not
    QuantConv2d/QuantLinear. Raises ValueError for a BatchNorm2d that appears
    before any conv/linear layer.
    '''
    if net is None:
        net = model
    model_param: List[ConvParam] = []
    feature_map_shape = list(in_shape)  # copy: never mutate the caller's list
    conv_cnt = 0
    conv_cur = None
    for sub_module in net.modules():
        # expect [QAct] -> [Pooling] -> Conv -> [BN] -> [Pooling], state machine mode
        if isinstance(sub_module, (HWGQ, ImageInputQ)):
            print('  Detected ActQ Layer', end='')
            if conv_cur is None: conv_cur = ConvParam()

            conv_cur.abit = sub_module.bit
            conv_cur.astep = sub_module.step

            conv_cur.actq_class = type(sub_module).__name__
            print(f', abit {conv_cur.abit}, astep {conv_cur.astep}, class {conv_cur.actq_class}')

            if conv_cnt: # previous.obit = cur.abit
                model_param[conv_cnt-1].obit = conv_cur.abit
                model_param[conv_cnt-1].ostep = conv_cur.astep

        elif isinstance(sub_module, torch.nn.Conv2d):
            if conv_cur is None: conv_cur = ConvParam()
            conv_cur.n = conv_cnt
            print('Extract conv_%d'%conv_cnt, end='')

            kh, kw = sub_module.kernel_size
            sh, sw = sub_module.stride
            ph, pw = sub_module.padding
            dh, dw = sub_module.dilation
            # The per-layer k/s/p fields keep the row-axis value (square kernels assumed downstream).
            conv_cur.k = kh
            conv_cur.s = sh
            conv_cur.p = ph
            conv_cur.ich = sub_module.in_channels
            conv_cur.och = sub_module.out_channels
            conv_cur.irow = feature_map_shape[1]
            conv_cur.icol = feature_map_shape[2]

            feature_map_shape[0] = sub_module.out_channels
            feature_map_shape[1] = _window_out_size(feature_map_shape[1], kh, sh, ph, dh)
            feature_map_shape[2] = _window_out_size(feature_map_shape[2], kw, sw, pw, dw)
            conv_cur.orow = feature_map_shape[1]
            conv_cur.ocol = feature_map_shape[2]

            assert sub_module.bias is None, 'inner conv has no bias in this model'
            if isinstance(sub_module, QuantConv2d): # New quant
                conv_cur.wbit = sub_module.bit
                conv_cur.w, conv_cur.wstep = sub_module.export_quant() # wstep is not QuantConv2d.step because of alpha
            else:
                raise NotImplementedError(sub_module)
            print(', ich {ich}, och {och}, irow {irow}, icol {icol}, ksp {k}{s}{p}, wbit {wbit}, wstep {wstep}'.format(**vars(conv_cur)))

            conv_cur.type = 'conv'
            model_param.append(conv_cur)
            conv_cur = None
            conv_cnt += 1

        elif isinstance(sub_module, torch.nn.Linear):
            if conv_cur is None: conv_cur = ConvParam() # TODO: independent type for linear layer
            conv_cur.n = conv_cnt
            print('Extract layer %d (linear layer)'%conv_cnt, end='')

            conv_cur.ich = sub_module.in_features
            conv_cur.och = sub_module.out_features
            conv_cur.irow = feature_map_shape[1]
            conv_cur.icol = feature_map_shape[2]

            if sub_module.bias is not None:
                conv_cur.convbias = sub_module.bias.detach().numpy()
                print(', +bias', end='')

            if isinstance(sub_module, QuantLinear): # New quant
                conv_cur.wbit = sub_module.bit
                conv_cur.w, conv_cur.wstep = sub_module.export_quant() # wstep is not QuantLinear.step because of alpha
            else:
                # Same policy as the conv branch; without wbit/wstep the print below would hit a KeyError.
                raise NotImplementedError(sub_module)

            print(', ich {ich}, och {och}, wbit {wbit}, wstep {wstep}'.format(**vars(conv_cur)))

            conv_cur.type = 'linear'
            model_param.append(conv_cur)
            conv_cur = None
            conv_cnt += 1

        elif isinstance(sub_module, torch.nn.BatchNorm2d):
            print('  Detected BatchNorm2d')
            if not model_param:
                raise ValueError('BatchNorm2d found before any conv/linear layer')
            gamma = sub_module.weight
            beta = sub_module.bias
            mean = sub_module.running_mean
            var = sub_module.running_var
            eps = sub_module.eps

            # Fold BN into an affine out = MAC*bn_w + bn_b of the preceding layer.
            model_param[-1].bn_w = (gamma / (torch.sqrt(var + eps))).detach().numpy()
            model_param[-1].bn_b = (beta - (mean / (torch.sqrt(var + eps)) * gamma)).detach().numpy()

        elif isinstance(sub_module, torch.nn.MaxPool2d):
            print('  Detected MaxPool2d')
            kh, kw = _as_pair(sub_module.kernel_size)
            sh, sw = _as_pair(sub_module.stride)
            ph, pw = _as_pair(sub_module.padding)
            dh, dw = _as_pair(sub_module.dilation)
            # For the default MaxPool2d(k) (stride=k, padding=0) this equals size // k.
            feature_map_shape[1] = _window_out_size(feature_map_shape[1], kh, sh, ph, dh)
            feature_map_shape[2] = _window_out_size(feature_map_shape[2], kw, sw, pw, dw)

    assert model_param and hasattr(model_param[0], 'abit'), \
        'no layer extracted, or the first layer is not preceded by an activation quantizer'

    return model_param

def process_batchnorm(model_param, lshift=16):
    '''process_batchnorm(model_param, lshift=16)
    Merge wstep, astep, ostep scale into batchnorm, then quantize.

    Method:
    Define MAC = Conv(w, a), out = MAC*BN_w + BN_b,
    wq = w/wstep, aq = a/astep, MACq = MAC/MACstep, outq = out/ostep.

    outq = (MAC*BN_w + BN_b) / ostep
         = MACq * (MACstep/ostep)*BN_w + BN_b/ostep
         = MACq *     inc_raw          + bias_raw
    next layer activation a' = ActQ(out), i.e. a'q = clip(round(outq))

    Quantization of inc_raw & bias_raw:
    outq_real = round((MACq*round(inc_raw*scale) + round(bias_raw*scale)) / scale)         ; where scale=2**T
              = (MACq*round(inc_raw*scale) + round(bias_raw*scale) + 0.5 * scale) // scale ; div floor
              = (MACq*        inc          +         bias          +  2**(T-1)  ) >> T     ; [!] the 2**(T-1) bias is done by hls code

    Params:
    T = (wbit-1)+abit+lshift  # This comes from dorefa quant, not optimal
    MBIT = wbit+abit+ceil(log2(sum_number))
    incbit = len(bit(inc)); biasbit = len(bit(bias))
    larger lshift is better, but MBIT+incbit<48

    Args:
    model_param: layers from extract_model. Every layer except the last needs
        bn_w, bn_b, wstep, astep, ostep, wbit, abit, n. The last layer needs
        wstep, astep, and optionally convbias.
    lshift: extra left shift added to T (default 16).

    Results are stored on the layer objects in place:
    lshift, inc_raw, bias_raw, inc, bias, lshift_T, incbit, biasbit.
    For the last layer: inc=None, div, and bias (None when it has no convbias).
    '''
    def bitlength(x):
        # Signed width: magnitude bits + 1 sign bit.
        return 1 + int(np.abs(x).max()).bit_length()

    for conv in model_param[:-1]:
        print(f'Process bn_{conv.n}, shape {conv.bn_w.shape},', end = ' ')

        # Merge step to BN
        conv.lshift = lshift
        MACstep = conv.wstep * conv.astep
        ostep = conv.ostep
        inc_raw = conv.bn_w * MACstep / ostep
        bias_raw = conv.bn_b / ostep
        conv.inc_raw = inc_raw
        conv.bias_raw = bias_raw

        # Quantization
        T = lshift+conv.wbit+conv.abit-1
        conv.inc = np.round(inc_raw * 2**T).astype(np.int64)
        conv.bias = np.round(bias_raw * 2**T).astype(np.int64)
        conv.lshift_T = T
        # Get bitlength
        conv.incbit = bitlength(conv.inc)
        conv.biasbit = bitlength(conv.bias)
        print(f'incbit {conv.incbit}, biasbit {conv.biasbit}, lshift_T {conv.lshift_T}')

    conv_last = model_param[-1] # process lastbias
    conv_last.inc = None
    conv_last.div = 1/(conv_last.wstep * conv_last.astep)
    convbias = getattr(conv_last, 'convbias', None)  # absent when the last layer has no bias
    if convbias is None:
        conv_last.bias = None
        print(f'conv_last has no bias, div {conv_last.div}')
    else:
        conv_last.bias = np.round(convbias * conv_last.div).astype(np.int64)
        conv_last.biasbit = bitlength(conv_last.bias)
        print(f'conv_last biasbit {conv_last.biasbit}, div {conv_last.div}')

def reorder_weight(model_param, layers_simd, layers_pe):
    '''reorder_weight(model_param, layers_simd, layers_pe)
    Reorder arrays in place into the layout expected by the hls code.

    Conv layers:
    - simd / pe are taken from layers_simd[n] / layers_pe[n].
    - inc / bias (when not None) become [pe, och/pe].
    - The weight [och, ich, kr, kc] becomes [pe, k, och/pe*k*ich/simd, simd],
      or [pe, och/pe*ich/simd, simd] when k == 1.

    Linear layers use a fixed, hard-coded layout (see the branch below).

    Raises AssertionError when:
    - och is not divisible by pe, or k*ich is not divisible by simd;
    - a linear weight does not match the hard-coded size.
    '''

    for conv in model_param:
        if conv.type == 'linear': #new reorder
            # NOTE(review): dimensions are hard-coded for 10 outputs over a
            # 256-channel 4x4 feature map (ich = 4096); other sizes are rejected.
            out_ch, in_ch, fm_h, fm_w = 10, 256, 4, 4
            pe_l = 1
            simd_l = 1
            in_pe_l = 8
            assert conv.w.size == out_ch * in_ch * fm_h * fm_w, \
                f"linear_{conv.n}: w {conv.w.shape} does not match hard-coded {out_ch}x{in_ch}x{fm_h}x{fm_w}"
            w = conv.w.reshape(out_ch, -1, fm_h, fm_w)
            w = w.reshape(out_ch // (2 * pe_l), pe_l, 2, in_ch // in_pe_l, in_pe_l // simd_l, simd_l, fm_h, fm_w)  #[OUT_CH/2PE, PE, 2, IN_CH/IN_PE, IN_PE/SIMD, SIMD, H, W]
            w = w.transpose(1, 6, 3, 7, 0, 4, 5, 2)                                                              #[PE, H, IN_CH/IN_PE, W, OUT_CH/2PE, IN_PE/SIMD, SIMD, 2]
            w = w.reshape(w.shape[0], w.shape[1], w.shape[2], w.shape[3], w.shape[4], w.shape[5], -1)            #[PE, H, IN_CH/IN_PE, W, OUT_CH/2PE, IN_PE/SIMD, SIMD * 2]
            print(w.shape)
            conv.w = w
            continue

        print(f'Reorder conv_{conv.n}, w {conv.w.shape}', end='')
        conv.simd = layers_simd[conv.n]
        conv.pe = layers_pe[conv.n]

        # Validate before any reshape, so a bad SIMD/PE config names the offending layer
        # instead of failing inside numpy.
        assert conv.och%conv.pe == 0, f"conv_{conv.n}, och {conv.och}, pe {conv.pe}"
        assert conv.k*conv.ich%conv.simd == 0, f"conv_{conv.n}, ich {conv.ich}, k {conv.k}, simd {conv.simd}"

        # process batchnorm: [och] -> [pe, och/pe]
        if conv.inc is not None:
            conv.inc = conv.inc.reshape(conv.och//conv.pe, conv.pe).T
        if conv.bias is not None:
            conv.bias = conv.bias.reshape(conv.och//conv.pe, conv.pe).T

        # process conv weight
        w = conv.w    # [och, ich, kr, kc]
        w = w.transpose(0, 3, 2, 1) # [och, kc, kr, ich]

        w = w.reshape(conv.och//conv.pe, conv.pe, conv.k, conv.k*conv.ich//conv.simd, conv.simd)
        w = w.transpose(1,2,0,3,4) # [pe, k, och/pe, k*ich/simd, simd]
        w = w.reshape(conv.pe, conv.k, -1, conv.simd) # hls format [pe, k, och/pe*k*ich/simd, simd]

        if conv.k == 1: # kernel size=1: drop the singleton k axis
            w = w.reshape(conv.pe, -1, conv.simd)
        print(' ->', w.shape)

        conv.w = w

def print_ndarray_recursion(arr, str_func=str, file=sys.stdout, stop=0):
    '''Print ``arr`` to ``file`` as a nested C brace initializer.

    Recursion bottoms out when ``arr`` is not iterable or has exactly ``stop``
    dimensions left; such a leaf is written as ``str_func(arr)``.

    Separators depend on the level:
    - Children of the level just above the leaves are separated by ``,`` only.
    - Higher levels put ``,`` plus a newline between children, and a newline
      before the closing brace.

    Nothing is printed after the final ``}``.
    '''
    is_leaf = (not hasattr(arr, '__iter__')) or len(arr.shape) == stop
    if is_leaf:
        print(str_func(arr), file=file, end='')
        return

    separator_tail = '\n' if len(arr.shape) != stop + 1 else ''
    last_index = len(arr) - 1
    print('{', file=file, end='')
    for index, child in enumerate(arr):
        print_ndarray_recursion(child, str_func, file, stop)
        if index != last_index:
            print(',', file=file, end=separator_tail)
    print(separator_tail + '}', file=file, end='')

def write_hls_linearlayer(layer, f):
    '''Write a linear layer's weights (and bias, if any) to ``f`` as C arrays.

    The weight array is declared as ``ap_int<wbit> linear_<n>_w[och][ich]``.
    When ``layer.bias`` is not None, the bias is declared as
    ``ap_int<biasbit> linear_<n>_bias[och]``. Every element is emitted as a
    quoted hexadecimal literal.

    NOTE(review): not called anywhere in the visible code;
    write_hls_linearlayer_reorder appears to be the variant in use.
    '''
    def quoted_hex(value):
        return f'"{hex(value)}"'

    idx = layer.n
    print(f"// layer: {idx}, wbit: {layer.wbit}", file=f)
    print(f"const ap_int<{layer.wbit}> linear_{idx}_w[{layer.och}][{layer.ich}]=", file=f)
    print_ndarray_recursion(layer.w, quoted_hex, f)
    print(';', file=f)

    if layer.bias is None:
        return
    print(f"const ap_int<{layer.biasbit}> linear_{idx}_bias[{layer.och}]=", file=f)
    print_ndarray_recursion(layer.bias, quoted_hex, f)
    print(';', file=f)

def write_hls_linearlayer_reorder(layer, d0, d1, d2, d3, d4, d5, d6, f):
    '''Write a reordered linear layer to ``f`` as packed C arrays.

    ``layer.w`` is printed as a 6-D array ``[d0]..[d5]`` of
    ``ap_uint<wbit*d6>``. Each innermost vector of ``d6`` signed ``wbit``
    values is packed into one quoted hex literal, last element in the most
    significant bits.

    When ``layer.bias`` is not None, it is written as an
    ``ap_int<biasbit>`` array of length ``och``.

    Raises AssertionError if a weight does not fit in signed ``wbit`` bits.
    '''
    idx = layer.n
    wbit = layer.wbit
    lowest, limit = -1 << (wbit - 1), 1 << (wbit - 1)
    lane_mask = 2**wbit - 1

    def quoted_hex(value):
        return '"' + hex(value) + '"'

    def pack_lanes(lanes):
        word = 0
        # [!] lanes are packed in reverse order; this matches the hls implementation.
        for lane in lanes[::-1]:
            lane = int(lane)  # Python big int, so wide words cannot overflow
            assert lowest <= lane < limit, f'got v={lane} while wbit={wbit}'
            word = (word << wbit) + (lane & lane_mask)
        return quoted_hex(word)

    print(f"// layer: {idx}, wbit: {wbit}", file=f)
    print(f"const ap_uint<{wbit * d6}> linear_{idx}_w[{d0}][{d1}][{d2}][{d3}][{d4}][{d5}]=", file=f)
    print_ndarray_recursion(layer.w, pack_lanes, f, stop=1)
    print(';', file=f)

    if layer.bias is None:
        return
    print(f"const ap_int<{layer.biasbit}> linear_{idx}_bias[{layer.och}]=", file=f)
    print_ndarray_recursion(layer.bias, quoted_hex, f)
    print(';', file=f)

def write_hls_weights(model_param, path):
    '''write_hls_weights(model_param, path)
    Write hls weights+inc+bias array code according to numpy shape.

    Creates ``path + 'weights.hpp'``.

    Conv layers are written as:
    - a packed ``ap_uint<wbit*simd>`` weight array (one hex literal per SIMD
      vector);
    - ``ap_int`` inc/bias arrays of shape [pe][och/pe], only when those are
      not None.

    Linear layers are delegated to ``write_hls_linearlayer_reorder``, passing
    the 7 dims of the reordered weight.

    The banner reads the module-level ``parser`` and ``opt`` objects created
    in ``__main__``. The file is closed even if a packing assertion fails
    mid-write.
    '''
    def hex_str(x):
        return '"' + hex(x) + '"'

    def make_simd_packer(wbit):
        # Build a str_func that packs one SIMD vector of signed wbit values into one hex literal.
        lo, hi = -1 << (wbit - 1), 1 << (wbit - 1)
        mask = 2**wbit - 1

        def pack1d_str(arr): # arr: 1d-array
            x = 0
            for v in arr[::-1]: # [!] reverse simd pack, it is related to hls implemention
                v = int(v) # use python bignumber, not np.int
                assert lo <= v < hi, f'got v={v} while wbit={wbit}'
                x = (x << wbit) + (v & mask)
            return hex_str(x)
        return pack1d_str

    with open(path + 'weights.hpp', 'w') as f:
        print(f'''/********************************************************************************
* Filename: weights.hpp
* Date: {time.ctime()}
* Description: This file is generated by {parser.prog}
*   ptfilename: {opt.weight} 
********************************************************************************/

#ifndef _WEIGHTS_HPP_
#define _WEIGHTS_HPP_
#include <ap_int.h>
''', file=f)

        for conv in model_param:
            if conv.type == 'linear':
                # The reordered linear weight is expected to be 7-D (see reorder_weight):
                # [PE, H, IN_CH/IN_PE, W, OUT_CH/2PE, IN_PE/SIMD, SIMD*2]
                write_hls_linearlayer_reorder(conv, *conv.w.shape[:7], f)
                continue

            n = conv.n
            print(f"Write conv_{n} weight, pe {conv.pe}, simd {conv.simd}, wbit {conv.wbit}")
            print(f"// layer: {n}, PE: {conv.pe}, SIMD: {conv.simd}, wbit: {conv.wbit}", file=f)

            # print conv weight, merge [SIMD] value into one ap_uint
            if conv.k>1:
                print(f"const ap_uint<{conv.wbit * conv.simd}> conv_{n}_w[{conv.pe}][{conv.k}][{conv.w.shape[2]}]=", file=f)
            else:
                print(f"const ap_uint<{conv.wbit * conv.simd}> conv_{n}_w[{conv.pe}][{conv.w.shape[1]}]=", file=f)
            print_ndarray_recursion(conv.w, make_simd_packer(conv.wbit), f, stop=1)
            print(';', file=f)

            # print inc, bias
            if conv.inc is not None:
                print(f"const ap_int<{conv.incbit}> conv_{n}_inc[{conv.pe}][{conv.och//conv.pe}]=", file=f)
                print_ndarray_recursion(conv.inc, hex_str, f)
                print(';', file=f)
            if conv.bias is not None:
                print(f"const ap_int<{conv.biasbit}> conv_{n}_bias[{conv.pe}][{conv.och//conv.pe}]=", file=f)
                print_ndarray_recursion(conv.bias, hex_str, f)
                print(';', file=f)

        print('#endif', file=f)

def adjust_weight(model_param):
    '''Clamp weights of layers whose (wbit, abit) packing cannot use -2**(wbit-1).

    For each such layer, ``w`` is replaced by ``max(w, -2**(wbit-1) + 1)``
    element-wise, so the most negative code is never emitted. Other layers are
    left untouched.
    '''
    # (wbit, abit) combinations whose packing can't quantize to -2**(wbit-1).
    # An earlier, shorter list was ((5,6), (7,3)).
    unsupported_pairs = ((4, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), (7, 2), (7, 3))
    for layer in model_param:
        if (layer.wbit, layer.abit) not in unsupported_pairs:
            continue
        print(f'Adjust conv_{layer.n} wbit={layer.wbit}')
        smallest_allowed = -2**(layer.wbit-1)+1
        layer.w = np.maximum(layer.w, smallest_allowed)

if __name__=='__main__':
    # `parser`, `opt` and `model` are module-level globals read by the functions above.
    parser = argparse.ArgumentParser()
    parser.add_argument('-w', '--weight', default=None, help='.pt file name in ./weights/')
    parser.add_argument('-m', '--model', default='VGG_tiny_FixQ', help = 'model class name in models.py')
    parser.add_argument('-c', '--config-simd-pe', default='config_simd_pe', help = '.txt file in ./hls/')
    opt = parser.parse_args()
    if opt.weight is None: opt.weight = select_weight_file()

    # One row per layer number (first line skipped as a header): column 0 -> SIMD, column 1 -> PE.
    # ndmin=2 keeps the table 2-D even for a single-layer config, so simd_pe[:, 0] still works.
    simd_pe = np.loadtxt('hls/'+opt.config_simd_pe+'.txt', dtype=int, skiprows=1, ndmin=2)
    dir_output = 'hls/' + opt.weight + '/'
    os.makedirs(dir_output, exist_ok=True)

    # load model and state_dict
    # NOTE(review): torch.load unpickles the file; only load trusted .pt files.
    ptfile:Dict = torch.load('weights/' + opt.weight + '.pt', map_location='cpu')
    model = getattr(models, opt.model)(**ptfile.get('model_params', {}))
    model.load_state_dict(ptfile['model'], strict = False)

    # process
    model_param = extract_model([1, 32, 32])
    adjust_weight(model_param)
    process_batchnorm(model_param) # get bn param before write hls config
    torch.save(model_param, dir_output + 'model_param.pkl')

    reorder_weight(model_param, simd_pe[:,0], simd_pe[:,1]) # get pe, simd param before write hls config
    write_hls_config(model_param, dir_output)
    write_hls_weights(model_param, dir_output)
